{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Introduction to SimpleITKv4 Registration <a href=\"https://mybinder.org/v2/gh/InsightSoftwareConsortium/SimpleITK-Notebooks/master?filepath=Python%2F60_Registration_Introduction.ipynb\"><img style=\"float: right;\" src=\"https://mybinder.org/badge_logo.svg\"></a>\n",
    "\n",
    "First, a reminder of some SimpleITK conventions:\n",
    "* The dimensionality of images being registered is required to be the same (i.e. 2D/2D or 3D/3D).\n",
    "* The pixel type of images being registered is required to be the same.\n",
    "* Supported pixel types for images being registered are sitkFloat32 and sitkFloat64. You can use the SimpleITK [Cast](http://www.simpleitk.org/doxygen/latest/html/namespaceitk_1_1simple.html#af8c9d7cc96a299a05890e9c3db911885) function if your image's pixel type is something else.\n",
    "\n",
    "## Registration Components \n",
    "\n",
    "<img src=\"ITKv4RegistrationComponentsDiagram.svg\"/><br><br>\n",
    "\n",
    "There are many options for creating an instance of the registration framework, all of which are configured in SimpleITK via methods of the <a href=\"http://www.simpleitk.org/doxygen/latest/html/classitk_1_1simple_1_1ImageRegistrationMethod.html\">ImageRegistrationMethod</a> class. This class encapsulates many of the components available in ITK for constructing a registration instance.\n",
    "\n",
    "Here is a list of currently available choices for each group of components:\n",
    "\n",
    "### Optimizers\n",
    "\n",
    "Several optimizer types are supported via the `SetOptimizerAsX()` methods. These include:\n",
    "* [Exhaustive](http://www.itk.org/Doxygen/html/classitk_1_1ExhaustiveOptimizerv4.html)\n",
    "* [Nelder-Mead downhill simplex](http://www.itk.org/Doxygen/html/classitk_1_1AmoebaOptimizerv4.html), a.k.a. Amoeba\n",
    "* [Powell optimizer](https://itk.org/Doxygen/html/classitk_1_1PowellOptimizerv4.html)\n",
    "* [1+1 evolutionary optimizer](https://itk.org/Doxygen/html/classitk_1_1OnePlusOneEvolutionaryOptimizerv4.html)\n",
    "* Variations on gradient descent:\n",
    "  * [GradientDescent](http://www.itk.org/Doxygen/html/classitk_1_1GradientDescentOptimizerv4Template.html)\n",
    "  * [GradientDescentLineSearch](http://www.itk.org/Doxygen/html/classitk_1_1GradientDescentLineSearchOptimizerv4Template.html)\n",
    "  * [RegularStepGradientDescent](http://www.itk.org/Doxygen/html/classitk_1_1RegularStepGradientDescentOptimizerv4.html)\n",
    "* [ConjugateGradientLineSearch](http://www.itk.org/Doxygen/html/classitk_1_1ConjugateGradientLineSearchOptimizerv4Template.html)\n",
    "* [L-BFGS2](https://itk.org/Doxygen/html/classitk_1_1LBFGS2Optimizerv4.html) (Limited memory Broyden, Fletcher, Goldfarb, Shannon)\n",
    "* [L-BFGS-B](http://www.itk.org/Doxygen/html/classitk_1_1LBFGSBOptimizerv4.html) (Limited memory Broyden, Fletcher, Goldfarb, Shannon - Bound Constrained) - supports the use of simple constraints.\n",
    "\n",
    "\n",
    "### Similarity metrics\n",
    "\n",
    "Several metric types are supported via the `SetMetricAsX()` methods. These include:\n",
    "* [MeanSquares](http://www.itk.org/Doxygen/html/classitk_1_1MeanSquaresImageToImageMetricv4.html)\n",
    "* [Demons](http://www.itk.org/Doxygen/html/classitk_1_1DemonsImageToImageMetricv4.html)\n",
    "* [Correlation](http://www.itk.org/Doxygen/html/classitk_1_1CorrelationImageToImageMetricv4.html)\n",
    "* [ANTSNeighborhoodCorrelation](http://www.itk.org/Doxygen/html/classitk_1_1ANTSNeighborhoodCorrelationImageToImageMetricv4.html)\n",
    "* [JointHistogramMutualInformation](http://www.itk.org/Doxygen/html/classitk_1_1JointHistogramMutualInformationImageToImageMetricv4.html)\n",
    "* [MattesMutualInformation](http://www.itk.org/Doxygen/html/classitk_1_1MattesMutualInformationImageToImageMetricv4.html)\n",
    "\n",
    "\n",
    "### Interpolators\n",
    "\n",
    "Several interpolators are supported via the `SetInterpolator()` method, which can take one of\n",
    "the following interpolators:\n",
    "* sitkNearestNeighbor \n",
    "* sitkLinear\n",
    "* sitkBSpline1, sitkBSpline2, sitkBSpline3, sitkBSpline4, sitkBSpline5, where the number denotes the spline order, with sitkBSpline3 being the go-to option.\n",
    "* sitkGaussian\n",
    "* sitkHammingWindowedSinc\n",
    "* sitkCosineWindowedSinc\n",
    "* sitkWelchWindowedSinc\n",
    "* sitkLanczosWindowedSinc\n",
    "* sitkBlackmanWindowedSinc\n",
    "\n",
    "See the [list of interpolators in the API reference](http://www.simpleitk.org/doxygen/latest/html/namespaceitk_1_1simple.html#a7cb1ef8bd02c669c02ea2f9f5aa374e5) for more information\n",
    "\n",
    "\n",
    "## Data -  Retrospective Image Registration Evaluation\n",
    "\n",
    "In this notebook we will be using part of the training data from the Retrospective Image Registration Evaluation ([RIRE](https://rire.insight-journal.org/)) project."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import SimpleITK as sitk\n",
    "\n",
    "# Utility method that either downloads data from the Girder repository or,\n",
    "# if already downloaded, returns the file name for reading from disk (cached data).\n",
    "%run update_path_to_download_script\n",
    "from downloaddata import fetch_data as fdata\n",
    "\n",
    "# Always write output to a separate directory so that we don't pollute the source directory.\n",
    "import os\n",
    "\n",
    "OUTPUT_DIR = \"Output\"\n",
    "# Create the output directory if it does not already exist.\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Utility functions\n",
    "First we define a number of utility callback functions for image display and for plotting the similarity metric during registration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from ipywidgets import interact, fixed\n",
    "from IPython.display import clear_output\n",
    "\n",
    "\n",
    "# Callback invoked by the IPython interact method for scrolling through image stacks of\n",
    "# the two images being registered.\n",
    "def display_images(fixed_image_z, moving_image_z, fixed_npa, moving_npa):\n",
    "    # Create a figure with two subplots and the specified size.\n",
    "    plt.subplots(1, 2, figsize=(10, 8))\n",
    "\n",
    "    # Draw the fixed image in the first subplot.\n",
    "    plt.subplot(1, 2, 1)\n",
    "    plt.imshow(fixed_npa[fixed_image_z, :, :], cmap=plt.cm.Greys_r)\n",
    "    plt.title(\"fixed image\")\n",
    "    plt.axis(\"off\")\n",
    "\n",
    "    # Draw the moving image in the second subplot.\n",
    "    plt.subplot(1, 2, 2)\n",
    "    plt.imshow(moving_npa[moving_image_z, :, :], cmap=plt.cm.Greys_r)\n",
    "    plt.title(\"moving image\")\n",
    "    plt.axis(\"off\")\n",
    "\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Callback invoked by the IPython interact method for scrolling and modifying the alpha blending\n",
    "# of an image stack of two images that occupy the same physical space.\n",
    "def display_images_with_alpha(image_z, alpha, fixed, moving):\n",
    "    img = (1.0 - alpha) * fixed[:, :, image_z] + alpha * moving[:, :, image_z]\n",
    "    plt.imshow(sitk.GetArrayViewFromImage(img), cmap=plt.cm.Greys_r)\n",
    "    plt.axis(\"off\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Callback invoked when the StartEvent happens, sets up our new data.\n",
    "def start_plot():\n",
    "    global metric_values, multires_iterations\n",
    "\n",
    "    metric_values = []\n",
    "    multires_iterations = []\n",
    "\n",
    "\n",
    "# Callback invoked when the EndEvent happens, do cleanup of data and figure.\n",
    "def end_plot():\n",
    "    global metric_values, multires_iterations\n",
    "\n",
    "    del metric_values\n",
    "    del multires_iterations\n",
    "    # Close figure, we don't want to get a duplicate of the plot latter on.\n",
    "    plt.close()\n",
    "\n",
    "\n",
    "# Callback invoked when the IterationEvent happens, update our data and display new figure.\n",
    "def plot_values(registration_method):\n",
    "    global metric_values, multires_iterations\n",
    "\n",
    "    metric_values.append(registration_method.GetMetricValue())\n",
    "    # Clear the output area (wait=True, to reduce flickering), and plot current data\n",
    "    clear_output(wait=True)\n",
    "    # Plot the similarity metric values\n",
    "    plt.plot(metric_values, \"r\")\n",
    "    plt.plot(\n",
    "        multires_iterations,\n",
    "        [metric_values[index] for index in multires_iterations],\n",
    "        \"b*\",\n",
    "    )\n",
    "    plt.xlabel(\"Iteration Number\", fontsize=12)\n",
    "    plt.ylabel(\"Metric Value\", fontsize=12)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Callback invoked when the sitkMultiResolutionIterationEvent happens, update the index into the\n",
    "# metric_values list.\n",
    "def update_multires_iterations():\n",
    "    global metric_values, multires_iterations\n",
    "    multires_iterations.append(len(metric_values))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read images\n",
    "\n",
    "Now we can read the images, casting the pixel type to that required for registration (Float32 or Float64) and have a look at them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fixed_image = sitk.ReadImage(fdata(\"training_001_ct.mha\"), sitk.sitkFloat32)\n",
    "moving_image = sitk.ReadImage(fdata(\"training_001_mr_T1.mha\"), sitk.sitkFloat32)\n",
    "\n",
    "interact(\n",
    "    display_images,\n",
    "    fixed_image_z=(0, fixed_image.GetSize()[2] - 1),\n",
    "    moving_image_z=(0, moving_image.GetSize()[2] - 1),\n",
    "    fixed_npa=fixed(sitk.GetArrayViewFromImage(fixed_image)),\n",
    "    moving_npa=fixed(sitk.GetArrayViewFromImage(moving_image)),\n",
    ");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initial Alignment\n",
    "\n",
    "To set sensible default values for the center of rotation and initial translation between the two images, the CenteredTransformInitializer is used to align the centers of the two volumes and set the center of rotation to the center of the fixed image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "initial_transform = sitk.CenteredTransformInitializer(\n",
    "    fixed_image,\n",
    "    moving_image,\n",
    "    sitk.Euler3DTransform(),\n",
    "    sitk.CenteredTransformInitializerFilter.GEOMETRY,\n",
    ")\n",
    "\n",
    "moving_resampled = sitk.Resample(\n",
    "    moving_image,\n",
    "    fixed_image,\n",
    "    initial_transform,\n",
    "    sitk.sitkLinear,\n",
    "    0.0,\n",
    "    moving_image.GetPixelID(),\n",
    ")\n",
    "\n",
    "interact(\n",
    "    display_images_with_alpha,\n",
    "    image_z=(0, fixed_image.GetSize()[2] - 1),\n",
    "    alpha=(0.0, 1.0, 0.05),\n",
    "    fixed=fixed(fixed_image),\n",
    "    moving=fixed(moving_resampled),\n",
    ");"
   ]
  },
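  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `GEOMETRY` initialization used above can be sketched in plain Python. This is a minimal illustration, assuming both images have an identity direction cosine matrix. The helper names (`geometric_center`, `geometry_initialization`) and the image sizes, spacings and origins are hypothetical and are not taken from the RIRE data. The geometric center of each image is taken as the physical location of the continuous index `(size - 1) / 2`, the center of rotation is the fixed image's center, and the translation maps the fixed center onto the moving center."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only: mimics the GEOMETRY initialization for images whose\n",
    "# direction cosine matrix is the identity. All names and numbers are hypothetical.\n",
    "def geometric_center(size, spacing, origin):\n",
    "    # Physical location of the continuous index (size - 1) / 2 along each axis.\n",
    "    return tuple(o + s * (n - 1) / 2.0 for n, s, o in zip(size, spacing, origin))\n",
    "\n",
    "\n",
    "def geometry_initialization(fixed_geometry, moving_geometry):\n",
    "    fixed_center = geometric_center(*fixed_geometry)\n",
    "    moving_center = geometric_center(*moving_geometry)\n",
    "    # The transform maps points from the fixed image domain to the moving image\n",
    "    # domain, so the translation takes the fixed center onto the moving center.\n",
    "    translation = tuple(m - f for f, m in zip(fixed_center, moving_center))\n",
    "    return fixed_center, translation\n",
    "\n",
    "\n",
    "# Each geometry is (size in voxels, spacing, origin); hypothetical values.\n",
    "center_of_rotation, initial_translation = geometry_initialization(\n",
    "    ((512, 512, 29), (0.5, 0.5, 4.0), (0.0, 0.0, 0.0)),\n",
    "    ((256, 256, 26), (1.25, 1.25, 4.0), (0.0, 0.0, 0.0)),\n",
    ")\n",
    "print(f\"center of rotation: {center_of_rotation}\")\n",
    "print(f\"initial translation: {initial_translation}\")"
   ]
  },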
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Registration\n",
    "\n",
    "The specific registration task at hand estimates a 3D rigid transformation between images of different modalities. There are multiple components from each group (optimizers, similarity metrics, interpolators) that are appropriate for the task. Note that selecting a component from each group requires setting some parameter values. We have made the following choices:\n",
    "\n",
    "- Similarity metric: Mutual information (Mattes MI)\n",
    "    - Number of histogram bins = 50\n",
    "    - Sampling strategy = random\n",
    "    - Sampling percentage = 1%\n",
    "- Interpolator: sitkLinear\n",
    "- Optimizer: gradient descent\n",
    "    - Learning rate, step size along traversal direction in parameter space: 1.0\n",
    "    - Number of iterations, maximum number of iterations: 100\n",
    "    - Convergence minimum value, value used for convergence checking in conjunction with the energy profile of the similarity metric that is estimated in the given window size: 1e-6\n",
    "    - Convergence window size, number of values of the similarity metric which are used to estimate the energy profile of the similarity metric: 10\n",
    "\n",
    "\n",
    "In the next code block we perform registration using the settings given above, and take advantage of the built in multi-resolution framework, using a three tier pyramid.  \n",
    "\n",
    "In this example we plot the similarity metric value during registration. Note that the change of scales in the multi-resolution framework is readily visible."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "registration_method = sitk.ImageRegistrationMethod()\n",
    "\n",
    "# Similarity metric settings.\n",
    "registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)\n",
    "registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)\n",
    "registration_method.SetMetricSamplingPercentage(0.01)\n",
    "\n",
    "registration_method.SetInterpolator(sitk.sitkLinear)\n",
    "\n",
    "# Optimizer settings.\n",
    "registration_method.SetOptimizerAsGradientDescent(\n",
    "    learningRate=1.0,\n",
    "    numberOfIterations=100,\n",
    "    convergenceMinimumValue=1e-6,\n",
    "    convergenceWindowSize=10,\n",
    ")\n",
    "registration_method.SetOptimizerScalesFromPhysicalShift()\n",
    "\n",
    "# Setup for the multi-resolution framework.\n",
    "registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])\n",
    "registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])\n",
    "registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()\n",
    "\n",
    "# Don't optimize in-place, we would possibly like to run this cell multiple times.\n",
    "registration_method.SetInitialTransform(initial_transform, inPlace=False)\n",
    "\n",
    "# Connect all of the observers so that we can perform plotting during registration.\n",
    "registration_method.AddCommand(sitk.sitkStartEvent, start_plot)\n",
    "registration_method.AddCommand(sitk.sitkEndEvent, end_plot)\n",
    "registration_method.AddCommand(\n",
    "    sitk.sitkMultiResolutionIterationEvent, update_multires_iterations\n",
    ")\n",
    "registration_method.AddCommand(\n",
    "    sitk.sitkIterationEvent, lambda: plot_values(registration_method)\n",
    ")\n",
    "\n",
    "final_transform = registration_method.Execute(\n",
    "    sitk.Cast(fixed_image, sitk.sitkFloat32), sitk.Cast(moving_image, sitk.sitkFloat32)\n",
    ")"
   ]
  },
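  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feel for the three-tier pyramid configured above, the following sketch estimates the grid size at each level. It is a rough illustration, assuming that each level's size is the full-resolution size divided by the shrink factor and rounded down; the helper `pyramid_sizes` and the example image size are hypothetical and are not taken from the RIRE data. The coarsest level (shrink factor 4) is processed first and the full-resolution level (shrink factor 1) last."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch only: approximate grid size at each pyramid level, assuming\n",
    "# each level's size is the full-resolution size divided by the shrink factor,\n",
    "# rounded down (and never smaller than one voxel). The image size is hypothetical.\n",
    "def pyramid_sizes(image_size, shrink_factors):\n",
    "    return [\n",
    "        tuple(max(n // factor, 1) for n in image_size) for factor in shrink_factors\n",
    "    ]\n",
    "\n",
    "\n",
    "example_size = (512, 512, 29)\n",
    "example_shrink_factors = [4, 2, 1]\n",
    "# Smoothing sigmas are in physical units, matching the registration settings.\n",
    "example_smoothing_sigmas = [2, 1, 0]\n",
    "level_sizes = pyramid_sizes(example_size, example_shrink_factors)\n",
    "for level, (factor, sigma, size) in enumerate(\n",
    "    zip(example_shrink_factors, example_smoothing_sigmas, level_sizes)\n",
    "):\n",
    "    print(f\"level {level}: shrink factor {factor}, sigma {sigma}, grid size {size}\")"
   ]
  },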
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Post registration analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Having finished registration, we next query the registration method to see the metric value and the reason the optimization terminated.\n",
    "\n",
    "The metric value allows us to compare multiple registration runs as there is a probabilistic aspect to our registration. This is due to the random sampling used to estimate the similarity metric.\n",
    "\n",
    "Always remember to query why the optimizer terminated. This will help you understand whether termination was premature, either because a threshold is too tight (the maximum number of iterations, `numberOfIterations`, is too small) or because it is too loose (the minimal change in the similarity measure, `convergenceMinimumValue`, is too large)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Final metric value: {registration_method.GetMetricValue()}\")\n",
    "print(\n",
    "    f\"Optimizer's stopping condition, {registration_method.GetOptimizerStopConditionDescription()}\"\n",
    ")"
   ]
  },
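  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The run-to-run variability mentioned above comes from drawing a different random subset of voxels on each run. The sketch below illustrates the effect with the Python standard library only; it is not SimpleITK's sampler, and the helper `sampled_mean` and the synthetic values are hypothetical. An estimate computed from a 1% random sample generally changes between runs unless the random seed is fixed. In SimpleITK, `SetMetricSamplingPercentage` accepts an optional `seed` argument that should serve the same purpose."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "# Illustrative sketch only, not SimpleITK's sampler: estimate the mean of a list\n",
    "# of values from a random subset containing the given fraction of the values.\n",
    "def sampled_mean(values, percentage, seed=None):\n",
    "    rng = random.Random(seed)  # seed=None seeds from system entropy or the clock.\n",
    "    count = max(int(len(values) * percentage), 1)\n",
    "    sample = rng.sample(values, count)\n",
    "    return sum(sample) / count\n",
    "\n",
    "\n",
    "# Synthetic stand-in for voxel intensities (hypothetical data).\n",
    "synthetic_values = [float(i % 256) for i in range(100_000)]\n",
    "\n",
    "# Unseeded estimates generally differ from one call to the next.\n",
    "print(sampled_mean(synthetic_values, 0.01), sampled_mean(synthetic_values, 0.01))\n",
    "# Estimates computed with the same seed are identical.\n",
    "print(\n",
    "    sampled_mean(synthetic_values, 0.01, seed=42),\n",
    "    sampled_mean(synthetic_values, 0.01, seed=42),\n",
    ")"
   ]
  },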
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now visually inspect the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "moving_resampled = sitk.Resample(\n",
    "    moving_image,\n",
    "    fixed_image,\n",
    "    final_transform,\n",
    "    sitk.sitkLinear,\n",
    "    0.0,\n",
    "    moving_image.GetPixelID(),\n",
    ")\n",
    "\n",
    "interact(\n",
    "    display_images_with_alpha,\n",
    "    image_z=(0, fixed_image.GetSize()[2] - 1),\n",
    "    alpha=(0.0, 1.0, 0.05),\n",
    "    fixed=fixed(fixed_image),\n",
    "    moving=fixed(moving_resampled),\n",
    ");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we are satisfied with the results, save them to file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sitk.WriteImage(\n",
    "    moving_resampled, os.path.join(OUTPUT_DIR, \"RIRE_training_001_mr_T1_resampled.mha\")\n",
    ")\n",
    "sitk.WriteTransform(\n",
    "    final_transform, os.path.join(OUTPUT_DIR, \"RIRE_training_001_CT_2_mr_T1.tfm\")\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the cell above we separately saved (1) a resampled version of the moving image and (2) the final registration transformation. Another option is to forgo the resampling and combine the transformation into the original moving image, \"physically\" moving it in space. This operation modifies the image's origin and direction cosine matrix but leaves the intensity information as is; the image size and spacing do not change. In Slicer nomenclature this is referred to as \"transform hardening\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transformed_moving = sitk.TransformGeometry(moving_image, final_transform)\n",
    "print(\n",
    "    f\"origin before: {moving_image.GetOrigin()}\\norigin after: {transformed_moving.GetOrigin()}\"\n",
    ")\n",
    "print(\n",
    "    f\"direction cosine before: {moving_image.GetDirection()}\\ndirection cosine after: {transformed_moving.GetDirection()}\"\n",
    ")\n",
    "sitk.WriteImage(\n",
    "    transformed_moving,\n",
    "    os.path.join(OUTPUT_DIR, \"RIRE_training_001_mr_T1_transformed.mha\"),\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
